In [ ]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import spectral_norm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,Subset
import matplotlib.pyplot as plt
import numpy as np
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import glob
import matplotlib.image as mpimg
In [ ]:
# ----------------------------------------------------------
# Device
# ----------------------------------------------------------
# Prefer Apple MPS, then CUDA, then CPU.
# torch.backends.mps.is_available() is the documented, version-stable API;
# torch.mps.is_available() only exists in newer PyTorch releases.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Using device {device}')
Using device mps
In [ ]:
# ----------------------------------------------------------
# Hyperparameters (Complete SAGAN)
# ----------------------------------------------------------
EPOCHS = 550                # total training epochs
BATCH_SIZE = 128
IMAGE_SIZE = 32             # CIFAR-10 native resolution
CHANNELS_IMG = 3            # RGB
LATENT_DIM = 128            # size of the generator noise vector z
EMBED_DIM = 50              # class-label embedding size (shared dim for G and D)
GEN_LR = 1e-4               # TTUR: generator learns at 1/4 the discriminator rate
DISC_LR = 4e-4
BETA1, BETA2 = 0.0, 0.9     # Adam betas as used in the SAGAN paper
CHECKPOINT_EVERY = 20       # epochs between checkpoint + samples + IS/FID
AUTOMOBILE_CLASS_IDX = 1    # CIFAR-10 class index for 'automobile'
In [ ]:
# ----------------------------------------------------------
# Self-Attention Block
# ----------------------------------------------------------
class SelfAttention(nn.Module):
    """SAGAN-style self-attention over spatial positions of a feature map.

    Attends every position to every other position and adds the attended
    features back onto the input, scaled by a learned scalar ``gamma`` that
    is zero-initialised — so the layer is an exact identity at init and
    attention is blended in gradually during training.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Query/key operate in a reduced channel space (c // 8 per the SAGAN
        # paper); clamp to >= 1 so the layer also works for in_channels < 8
        # (the original // 8 would create zero-channel convolutions).
        reduced = max(in_channels // 8, 1)
        self.query = spectral_norm(nn.Conv2d(in_channels, reduced, 1))
        self.key   = spectral_norm(nn.Conv2d(in_channels, reduced, 1))
        self.value = spectral_norm(nn.Conv2d(in_channels, in_channels, 1))
        # Learned residual weight, starts at 0 (identity layer).
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """x: (b, c, h, w) -> (b, c, h, w); shape preserved."""
        b, c, h, w = x.size()
        query_out = self.query(x).view(b, -1, w*h)                 # (b, reduced, h*w)
        key_out   = self.key(x).view(b, -1, w*h)                   # (b, reduced, h*w)
        attn      = torch.bmm(query_out.permute(0, 2, 1), key_out) # (b, h*w, h*w)
        attn      = torch.softmax(attn, dim=-1)                    # rows sum to 1
        value_out = self.value(x).view(b, c, w*h)                  # (b, c, h*w)
        out       = torch.bmm(value_out, attn.permute(0, 2, 1))    # (b, c, h*w)
        out       = out.view(b, c, h, w)
        return self.gamma * out + x
In [ ]:
# ----------------------------------------------------------
# CIFAR-10 Data Loading
# ----------------------------------------------------------
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),   # CIFAR-10 is already 32x32; effectively a no-op
    transforms.ToTensor(),           # [0, 255] -> [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # -> [-1, 1], matches Tanh output
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Filter via the raw label list (trainset.targets) instead of iterating the
# dataset itself, which would decode and transform all 50k images just to
# read their labels.
automobile_indices = [i for i, label in enumerate(trainset.targets)
                      if label == AUTOMOBILE_CLASS_IDX]
automobile_dataset = Subset(trainset, automobile_indices)

# Create dataloader with only automobile images
trainloader = DataLoader(
    automobile_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)
In [ ]:
# ----------------------------------------------------------
# Generator (SAGAN with Spectral Norm)
# ----------------------------------------------------------
class Generator(nn.Module):
    """Conditional SAGAN generator: (z, class label) -> 32x32 RGB image in [-1, 1].

    Submodule names (label_emb / init_fc / conv_blocks) are part of the
    checkpoint state-dict layout — do not rename.
    """

    def __init__(self, latent_dim, embed_dim, num_classes=10):
        super().__init__()
        # Learned class embedding, concatenated with the noise vector.
        self.label_emb = nn.Embedding(num_classes, embed_dim)

        # Project noise + label embedding to a 4x4x512 spatial seed.
        self.init_fc = nn.Sequential(
            spectral_norm(nn.Linear(latent_dim + embed_dim, 4*4*512)),
            nn.BatchNorm1d(4*4*512),
            nn.ReLU(True)
        )

        # Three stride-2 transposed convs upsample 4x4 -> 8x8 -> 16x16 -> 32x32,
        # with self-attention inserted at the 8x8 / 256-channel stage.
        self.conv_blocks = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(512, 256, 4, 2, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            SelfAttention(256),
            spectral_norm(nn.ConvTranspose2d(256, 128, 4, 2, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            spectral_norm(nn.ConvTranspose2d(128, CHANNELS_IMG, 4, 2, 1)),
            nn.Tanh()
        )

    def forward(self, z, labels):
        """z: (b, latent_dim); labels: (b,) int64 -> (b, CHANNELS_IMG, 32, 32)."""
        label_vec = self.label_emb(labels)
        joint = torch.cat([z, label_vec], dim=1)
        seed = self.init_fc(joint)
        return self.conv_blocks(seed.view(-1, 512, 4, 4))
In [ ]:
# ----------------------------------------------------------
# Discriminator (SAGAN with Spectral Norm)
# ----------------------------------------------------------
class Discriminator(nn.Module):
    """Conditional SAGAN critic: (image, class label) -> single real/fake logit.

    Submodule names (label_emb / conv_blocks / fc) are part of the checkpoint
    state-dict layout — do not rename.
    """

    def __init__(self, embed_dim, num_classes=10):
        super().__init__()
        # Class embedding, concatenated with image features before the head.
        self.label_emb = nn.Embedding(num_classes, embed_dim)

        # Four stride-2 convs: 32x32 -> 16 -> 8 -> 4 -> 2, attention at 8x8.
        self.conv_blocks = nn.Sequential(
            spectral_norm(nn.Conv2d(CHANNELS_IMG, 64, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(128),
            spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Linear head over flattened 512x2x2 features plus the label embedding.
        self.fc = spectral_norm(nn.Linear(512*2*2 + embed_dim, 1))

    def forward(self, x, labels):
        """x: (b, CHANNELS_IMG, 32, 32); labels: (b,) int64 -> (b, 1) raw logit."""
        feats = self.conv_blocks(x)
        flat = feats.view(x.size(0), -1)
        return self.fc(torch.cat([flat, self.label_emb(labels)], dim=1))
In [ ]:
# ----------------------------------------------------------
# Initialize Model, Loss, Optimizers
# ----------------------------------------------------------
gen = Generator(LATENT_DIM, EMBED_DIM).to(device)
disc = Discriminator(EMBED_DIM).to(device)
# BCE on raw logits — the discriminator has no final sigmoid.
criterion = nn.BCEWithLogitsLoss()

# TTUR: discriminator lr (4e-4) is 4x the generator lr (1e-4);
# Adam betas (0.0, 0.9) follow the SAGAN paper.
opt_gen = optim.Adam(gen.parameters(), lr=GEN_LR, betas=(BETA1, BETA2))
opt_disc = optim.Adam(disc.parameters(), lr=DISC_LR, betas=(BETA1, BETA2))

checkpoint_path = "adl_part_4.pt"
start_epoch = 1  # overwritten by the checkpoint-resume cell if a checkpoint exists
In [ ]:
# ----------------------------------------------------------
# Check for Existing Checkpoint
# ----------------------------------------------------------
# Resume models, optimizers and epoch counter from the last saved checkpoint.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# you produced yourself (consider weights_only=True on recent PyTorch).
if os.path.exists(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location=device)
    gen.load_state_dict(ckpt["gen_state_dict"])
    disc.load_state_dict(ckpt["disc_state_dict"])
    opt_gen.load_state_dict(ckpt["opt_gen_state_dict"])
    opt_disc.load_state_dict(ckpt["opt_disc_state_dict"])
    start_epoch = ckpt["epoch"] + 1  # continue after the saved epoch
In [ ]:
# ----------------------------------------------------------
# Utility: Generate & Show 10 Samples
# ----------------------------------------------------------
def generate_and_show_samples(epoch):
    """Sample 10 automobile images from the generator, display and save them.

    Saves to task4/automobile_gan_losses_<epoch>.png (directory created on
    demand — savefig would otherwise fail on a fresh machine) and restores
    the generator to train mode afterwards.
    """
    gen.eval()
    with torch.no_grad():
        z = torch.randn(10, LATENT_DIM, device=device)
        labels = torch.full((10,), AUTOMOBILE_CLASS_IDX, dtype=torch.long, device=device)
        samples = gen(z, labels).cpu()
    samples = (samples + 1) / 2.0  # Tanh output [-1, 1] -> [0, 1] for imshow
    fig, axes = plt.subplots(1, 10, figsize=(22, 2.4))
    for i in range(10):
        img = samples[i].permute(1, 2, 0).numpy()  # CHW -> HWC
        axes[i].imshow(img)
        axes[i].axis('off')
    plt.suptitle(f"Epoch {epoch}: SAGAN Samples (Automobile)", fontsize=14)
    os.makedirs('task4', exist_ok=True)  # ensure output directory exists
    plt.savefig(f'task4/automobile_gan_losses_{epoch}.png')
    plt.show()
    plt.close(fig)  # free figure memory across the many checkpoint calls
    gen.train()
In [ ]:
# ----------------------------------------------------------
# Compute IS & FID
# ----------------------------------------------------------
def compute_is_fid(generator, loader, n_samples=2000):
    """Estimate Inception Score (on fakes) and FID (reals vs. fakes).

    Metrics run on CPU over uint8 images in [0, 255], as torchmetrics
    expects. Exactly ``n_samples`` real and ``n_samples`` fake images are
    fed — final batches are truncated so the two FID statistics are
    computed over balanced sample counts (the original code overshot both).

    Returns:
        (is_mean, fid) as Python floats.
    """
    is_metric = InceptionScore().to("cpu")
    fid_metric = FrechetInceptionDistance().to("cpu")
    generator.eval()

    # ---- real images ------------------------------------------------------
    real_count = 0
    for real_imgs, _ in loader:
        real_imgs = real_imgs[: n_samples - real_count].to(device)
        # Undo Normalize((0.5,)*3, (0.5,)*3): [-1, 1] -> [0, 255] uint8.
        real_imgs_uint8 = (((real_imgs * 0.5) + 0.5) * 255).to(torch.uint8).cpu()
        fid_metric.update(real_imgs_uint8, real=True)
        real_count += real_imgs.size(0)
        if real_count >= n_samples:
            break

    # ---- fake images ------------------------------------------------------
    # NOTE(review): fake labels are drawn uniformly over all 10 classes while
    # the real side contains only automobiles — consider conditioning on
    # AUTOMOBILE_CLASS_IDX for a like-for-like FID. Kept as-is pending review.
    fake_count = 0
    while fake_count < n_samples:
        batch = min(BATCH_SIZE, n_samples - fake_count)  # don't overshoot n_samples
        z = torch.randn(batch, LATENT_DIM, device=device)
        labels = torch.randint(0, 10, (batch,), dtype=torch.long, device=device)
        with torch.no_grad():
            fake_out = generator(z, labels)
        fake_out_uint8 = (((fake_out * 0.5) + 0.5) * 255).to(torch.uint8).cpu()
        is_metric.update(fake_out_uint8)
        fid_metric.update(fake_out_uint8, real=False)
        fake_count += batch

    inception_score = is_metric.compute()  # (mean, std)
    fid_score = fid_metric.compute()
    generator.train()
    return inception_score[0].item(), fid_score.item()
In [ ]:
# ----------------------------------------------------------
# Training Loop
# ----------------------------------------------------------
for epoch in range(start_epoch, EPOCHS + 1):
    for real, labels in trainloader:  # batch index was unused; enumerate removed
        real, labels = real.to(device), labels.to(device)
        bsz = real.size(0)

        # --------------------
        # Train Discriminator: push D(real) -> 1 and D(fake) -> 0
        # --------------------
        disc.zero_grad()
        noise = torch.randn(bsz, LATENT_DIM, device=device)
        # Fakes are conditioned on uniformly random class labels even though
        # the loader only yields automobiles (class 1).
        rand_labels = torch.randint(0, 10, (bsz,), dtype=torch.long, device=device)

        pred_real = disc(real, labels)
        loss_real = criterion(pred_real, torch.ones_like(pred_real))

        fake = gen(noise, rand_labels)
        # detach(): the discriminator step must not backprop into the generator.
        pred_fake = disc(fake.detach(), rand_labels)
        loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))

        lossD = loss_real + loss_fake
        lossD.backward()
        opt_disc.step()

        # ----------------
        # Train Generator: non-saturating loss, push D(fake) -> 1
        # ----------------
        gen.zero_grad()
        pred_gen = disc(fake, rand_labels)  # fresh D forward after its update
        lossG = criterion(pred_gen, torch.ones_like(pred_gen))
        lossG.backward()
        opt_gen.step()

    # Losses from the final batch only — a cheap per-epoch progress indicator.
    print(f"[Epoch {epoch}/{EPOCHS}]  LossD: {lossD.item():.4f}  LossG: {lossG.item():.4f}")

    if epoch % CHECKPOINT_EVERY == 0:
        # Save everything needed to resume: weights, optimizer state, epoch.
        save_data = {
            "epoch": epoch,
            "gen_state_dict": gen.state_dict(),
            "disc_state_dict": disc.state_dict(),
            "opt_gen_state_dict": opt_gen.state_dict(),
            "opt_disc_state_dict": opt_disc.state_dict()
        }
        torch.save(save_data, checkpoint_path)
        print(f"[epoch={epoch}]Checkpoint saved: {checkpoint_path}")
        generate_and_show_samples(epoch)
        is_val, fid_val = compute_is_fid(gen, trainloader)
        print(f"==> Epoch {epoch}: Inception Score = {is_val:.4f}, FID = {fid_val:.4f}")

print("Training complete!")
[Epoch 501/550]  LossD: 0.2583  LossG: 8.0771
[Epoch 502/550]  LossD: 0.0396  LossG: 11.6219
[Epoch 503/550]  LossD: 0.2819  LossG: 11.0093
[Epoch 504/550]  LossD: 0.0392  LossG: 8.3315
[Epoch 505/550]  LossD: 0.1530  LossG: 7.2564
[Epoch 506/550]  LossD: 0.2484  LossG: 8.0930
[Epoch 507/550]  LossD: 0.1423  LossG: 6.9417
[Epoch 508/550]  LossD: 0.1764  LossG: 7.5808
[Epoch 509/550]  LossD: 0.2741  LossG: 8.0010
[Epoch 510/550]  LossD: 0.0762  LossG: 8.3668
[Epoch 511/550]  LossD: 0.2209  LossG: 9.2348
[Epoch 512/550]  LossD: 0.0600  LossG: 7.1931
[Epoch 513/550]  LossD: 0.1417  LossG: 10.0495
[Epoch 514/550]  LossD: 0.4248  LossG: 7.6931
[Epoch 515/550]  LossD: 0.1376  LossG: 6.9469
[Epoch 516/550]  LossD: 0.7187  LossG: 15.0068
[Epoch 517/550]  LossD: 0.3110  LossG: 8.2563
[Epoch 518/550]  LossD: 0.0804  LossG: 6.7781
[Epoch 519/550]  LossD: 0.1688  LossG: 7.4880
[Epoch 520/550]  LossD: 0.2271  LossG: 8.3182
[epoch=520]Checkpoint saved: adl_part_4.pt
No description has been provided for this image
/Users/shivamsahil/Downloads/bits/assignments/venv/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `InceptionScore` will save all extracted features in buffer. For large datasets this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)  # noqa: B028
==> Epoch 520: Inception Score = 1.6597, FID = 343.0216
[Epoch 521/550]  LossD: 0.1547  LossG: 8.0558
[Epoch 522/550]  LossD: 0.1291  LossG: 6.8157
[Epoch 523/550]  LossD: 0.1425  LossG: 10.0711
[Epoch 524/550]  LossD: 0.2498  LossG: 7.4953
[Epoch 525/550]  LossD: 0.2674  LossG: 5.8651
[Epoch 526/550]  LossD: 0.3667  LossG: 9.8780
[Epoch 527/550]  LossD: 0.0309  LossG: 7.9693
[Epoch 528/550]  LossD: 0.1494  LossG: 8.0814
[Epoch 529/550]  LossD: 0.2831  LossG: 5.3990
[Epoch 530/550]  LossD: 0.3526  LossG: 7.0946
[Epoch 531/550]  LossD: 0.6136  LossG: 16.5109
[Epoch 532/550]  LossD: 0.0501  LossG: 10.2105
[Epoch 533/550]  LossD: 0.3019  LossG: 6.9905
[Epoch 534/550]  LossD: 0.3236  LossG: 7.8011
[Epoch 535/550]  LossD: 0.4857  LossG: 8.6465
[Epoch 536/550]  LossD: 0.3543  LossG: 8.8010
[Epoch 537/550]  LossD: 0.0894  LossG: 9.2096
[Epoch 538/550]  LossD: 0.3729  LossG: 6.4945
[Epoch 539/550]  LossD: 0.0231  LossG: 8.9839
[Epoch 540/550]  LossD: 0.3395  LossG: 6.9768
[epoch=540]Checkpoint saved: adl_part_4.pt
No description has been provided for this image
==> Epoch 540: Inception Score = 1.6411, FID = 363.5405
[Epoch 541/550]  LossD: 0.1517  LossG: 7.6274
[Epoch 542/550]  LossD: 0.1124  LossG: 15.2578
[Epoch 543/550]  LossD: 0.1488  LossG: 8.4764
[Epoch 544/550]  LossD: 0.3222  LossG: 6.4993
[Epoch 545/550]  LossD: 0.3849  LossG: 11.2570
[Epoch 546/550]  LossD: 0.0509  LossG: 11.5791
[Epoch 547/550]  LossD: 0.4406  LossG: 8.0726
[Epoch 548/550]  LossD: 0.0303  LossG: 7.4664
[Epoch 549/550]  LossD: 0.0429  LossG: 8.0287
[Epoch 550/550]  LossD: 0.2678  LossG: 10.0007
Training complete!
In [ ]:
directory = r'task4'

def extract_epoch(filename):
    """Sort key: pull the integer epoch out of 'automobile_gan_losses_<epoch>.png'.

    Filenames that don't match the pattern sort last (float('inf')).
    """
    name = os.path.basename(filename)
    try:
        return int(name.split('automobile_gan_losses_')[1].split('.')[0])
    except (IndexError, ValueError):
        return float('inf')

# Collect every PNG under `directory`, ordered numerically by epoch.
png_files = sorted(glob.glob(os.path.join(directory, '*.png')), key=extract_epoch)

# Display every checkpoint sample sheet, one row per epoch.
if not png_files:
    print("No PNG files found in the directory:", directory)
else:
    n = len(png_files)

    # Scale the figure height with the number of rows so each stays readable.
    fig, axs = plt.subplots(n, 1, figsize=(22, 2.4 * n))

    # plt.subplots returns a bare Axes (not an array) when n == 1.
    axs = [axs] if n == 1 else axs

    # Best-effort window maximisation — the API differs per backend.
    mng = plt.get_current_fig_manager()
    try:
        mng.window.state('zoomed')          # TkAgg
    except AttributeError:
        try:
            mng.window.showMaximized()      # Qt backends
        except Exception:
            pass  # fall back to the fixed figsize above

    # Render each saved PNG with its filename as the row title.
    for ax, path in zip(axs, png_files):
        ax.imshow(mpimg.imread(path), aspect='auto')
        ax.axis('off')
        ax.set_title(os.path.basename(path), fontsize=14)

    plt.tight_layout()
    plt.show()
No description has been provided for this image
In [ ]:
# Install necessary packages (Colab-only cell: TeX toolchain + pandoc for nbconvert)
!apt-get install texlive texlive-xetex texlive-latex-extra pandoc
!pip install pypandoc

# Mount Google Drive
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

# Copy the notebook to the current directory
!cp 'drive/My Drive/Colab Notebooks/Assignment2_Group75_Task4.ipynb' ./

# Convert the notebook to HTML (note: `--to html`, not PDF, despite the TeX
# install above) while keeping the code and output
!jupyter nbconvert --to html "Assignment2_Group75_Task4.ipynb"


# Download the generated HTML file
from google.colab import files
files.download('Assignment2_Group75_Task4.html')